Conversation

@markurtz markurtz commented Oct 1, 2025

…icated combinations

## Summary

## Details

- [ ]

## Test Plan

## Related Issues

- Resolves #

---

- [ ] "I certify that all code in this PR is my own, except as noted below."

## Use of AI

- [ ] Includes AI-assisted code completion
- [ ] Includes code generated by an AI application
- [ ] Includes AI-generated tests (NOTE: AI-written tests should have a docstring that includes `## WRITTEN BY AI ##`)

@markurtz markurtz requested a review from sjmonson October 1, 2025 12:06
@markurtz markurtz self-assigned this Oct 1, 2025
@markurtz markurtz changed the title Initial state for datasets rework to enable multimodal and more compl… [GuideLLM Refactor] Data pipelines rework and multimodal support Oct 1, 2025
sjmonson and others added 2 commits October 3, 2025 14:00
## TODO

- Docs
- ~~CSV arg string support~~ CSV arg string now supports a single bucket (see the last example). Might leave it at that for now.
- More validation

## Summary


This PR is a port of #287 to the v0.4.0 refactor branch.

Adds controls for sharing one or more fixed prefixes between samples. See the examples below.

## Details


Adds a `prefix_buckets` argument to the `SyntheticTextDatasetConfig`. Each bucket consists of a prefix count, a prefix token count, and a bucket weight. The prefix count sets the number of unique prefixes to generate for a given bucket, the prefix token count is the length of each prefix in the bucket, and the bucket weight determines the proportion of requests the bucket applies to, relative to the sum of all bucket weights. Here are a few examples:


Here we have one bucket of 32 prefixes, each 2048 tokens long. Since there are 1024 total samples, each prefix will apply to 32 samples. If there is only one bucket, then the weight can be omitted, since the bucket applies to 100% of the samples.

```yaml
data:
  prefix_buckets:
    - prefix_tokens: 2048
      prefix_count: 32
  prompt_tokens: 256
  output_tokens: 256
  samples: 1024
```

In this modified version of the first example, 16 of the prefixes have 2048 tokens while the other 16 have 1024 tokens.

```yaml
data:
  prefix_buckets:
    - prefix_tokens: 2048
      prefix_count: 16
      bucket_weight: 50
    - prefix_tokens: 1024
      prefix_count: 16
      bucket_weight: 50
  prompt_tokens: 256
  output_tokens: 256
  samples: 1024
```
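To make the weighting concrete, here is the arithmetic the weights imply for the example above (an illustrative sketch only, not the library's actual allocation code, which may split samples differently):

```python
# Illustrative arithmetic only; the actual sampling logic may differ
# (e.g., probabilistic rather than exact splits).
buckets = [
    {"prefix_tokens": 2048, "prefix_count": 16, "bucket_weight": 50},
    {"prefix_tokens": 1024, "prefix_count": 16, "bucket_weight": 50},
]
samples = 1024

total_weight = sum(b["bucket_weight"] for b in buckets)
for b in buckets:
    # Each bucket gets its weight's share of the total samples.
    bucket_samples = samples * b["bucket_weight"] // total_weight
    # Samples within a bucket are spread across its unique prefixes.
    per_prefix = bucket_samples // b["prefix_count"]
    print(
        f"{b['prefix_tokens']}-token prefixes: "
        f"{bucket_samples} samples, {per_prefix} per prefix"
    )
# -> 2048-token prefixes: 512 samples, 32 per prefix
# -> 1024-token prefixes: 512 samples, 32 per prefix
```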

The prefix tokens of a bucket can also be 0 to disable prefixes for
those samples. Here is an example where 40% of the samples have a prefix
of 2048 tokens while the other 60% have no prefix.

```yaml
data:
  prefix_buckets:
    - prefix_tokens: 2048
      bucket_weight: 40
    - prefix_tokens: 0
      bucket_weight: 60
  prompt_tokens: 256
  output_tokens: 256
  samples: 1000
```

If only a single bucket is needed, it can be set at the top level. This makes the change backwards compatible with the previous interface and allows the CSV string format to work without parsing nested structures (at least for this use case).

```yaml
data:
  prefix_tokens: 128
  prefix_count: 10
  prompt_tokens: 256
  output_tokens: 256
  samples: 1000
```
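For illustration, one way the top-level form could normalize into a single bucket internally (a hypothetical helper sketch, not the PR's actual implementation; the default `prefix_count` is an assumption):

```python
# Hypothetical sketch: collapse top-level prefix fields into a
# one-element prefix_buckets list. Field names follow the YAML examples.
def normalize_prefix_config(config: dict) -> dict:
    if "prefix_buckets" not in config and "prefix_tokens" in config:
        config["prefix_buckets"] = [
            {
                "prefix_tokens": config.pop("prefix_tokens"),
                "prefix_count": config.pop("prefix_count", 1),
                # weight omitted: a single bucket covers 100% of samples
            }
        ]
    return config
```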

## Test Plan

- PR includes unit tests for all synthetic dataset changes (`pytest tests/unit/dataset`)
- The scenarios in the Details section can be run against a model server with prefix caching enabled; the cache hit rate can be confirmed by inspecting the console output.

## Related Issues

- Resolves #232
- Closes #287

---

- [x] "I certify that all code in this PR is my own, except as noted
below."

## Use of AI

- [x] Includes AI-assisted code completion
- [ ] Includes code generated by an AI application
- [x] Includes AI-generated tests (NOTE: AI-written tests should have a docstring that includes `## WRITTEN BY AI ##`)

---------

Signed-off-by: Samuel Monson <[email protected]>

@jaredoconnell jaredoconnell left a comment

Just leaving a comment since this is in a pre-release state. I didn't see any major problems with the code after one pass.
I ran into some errors running the example command you sent to the group earlier, and I messaged you about that.
It definitely needs tests, more comments, and more documentation and examples. Example commands in the docs will make this PR easier to test.

Comment on lines 49 to 65
```python
args_dict = args if isinstance(args, dict) else args.model_dump()
combined["url"] = args_dict.get("url", combined.get("url"))
combined["path"] = args_dict.get("path", combined.get("path"))
combined["method"] = args_dict.get("method", combined.get("method"))
combined["stream"] = args_dict.get("stream", combined.get("stream"))
combined["content_body"] = args_dict.get(
    "content_body", combined.get("content_body")
)

if (json_body := args_dict.get("json_body")) is not None:
    combined["json_body"] = combined.get("json_body", {}) + json_body
if (files := args_dict.get("files")) is not None:
    combined["files"] = combined.get("files", {}) + files
if (params := args_dict.get("params")) is not None:
    combined["params"] = combined.get("params", {}) + params
if (headers := args_dict.get("headers")) is not None:
    combined["headers"] = combined.get("headers", {}) + headers
```

This could be simplified with loops, if desired.
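For illustration, a minimal sketch of the loop-based version, reusing `combined` and `args_dict` from the snippet above:

```python
# Sketch only: collapses the repeated get/merge calls into two loops.
# Assumes the same semantics as the original snippet, including the
# `+` merge on the json_body/files/params/headers values.
for key in ("url", "path", "method", "stream", "content_body"):
    combined[key] = args_dict.get(key, combined.get(key))

for key in ("json_body", "files", "params", "headers"):
    if (value := args_dict.get(key)) is not None:
        combined[key] = combined.get(key, {}) + value
```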

Comment on lines 52 to 59
```python
open_ai_paths: dict[str, str] = {
    "health": "health",
    "models": "v1/models",
    "text_completions": "v1/completions",
    "chat_completions": "v1/chat/completions",
    "audio_transcriptions": "v1/audio/transcriptions",
    "audio_translations": "v1/audio/translations",
}
```

Eventually we should make the `v1` part of the endpoint separately configurable (see #369), but that can be a follow-up.
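For example, a hedged sketch of what that could look like (`api_version` is a hypothetical setting, not part of this PR):

```python
# Hypothetical: hoist the API version out of the hardcoded paths so it
# can be overridden via configuration (see #369).
api_version: str = "v1"

open_ai_paths: dict[str, str] = {
    "health": "health",
    "models": f"{api_version}/models",
    "text_completions": f"{api_version}/completions",
    "chat_completions": f"{api_version}/chat/completions",
    "audio_transcriptions": f"{api_version}/audio/transcriptions",
    "audio_translations": f"{api_version}/audio/translations",
}
```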

Comment on lines 292 to 327
```python
async for chunk in stream.aiter_bytes():
    if not chunk or end_reached:
        continue
    buffer.extend(chunk)

    while (start := buffer.find(b"data:")) != -1 and (
        end := buffer.find(b"\n", start)
    ) != -1:
        line = buffer[start + len(b"data:") : end].strip()
        buffer = buffer[end + 1 :]

        if not line:
            continue

        if line == b"[DONE]":
            if request_info.request_timings.request_end is None:
                request_info.request_timings.request_end = time.time()
            end_reached = True
            break

        data = json.loads(line) if not HAS_ORJSON else orjson.loads(line)

        if "usage" in data and data["usage"] is not None:
            request_info.request_timings.request_end = time.time()
            prompt_stats, output_stats = self._extract_response_stats(
                data, request
            )
        else:
            if request_info.request_timings.first_iteration is None:
                request_info.request_timings.first_iteration = time.time()
            request_info.request_timings.last_iteration = time.time()
            deltas.append(self._extract_response_text(data))
```

Depending on the path taken, `time.time()` is called at different points in the processing.

Call `time.time()` once at the beginning of each loop iteration and assign it as appropriate to `request_info.request_timings.request_end`, `request_info.request_timings.first_iteration`, and/or `request_info.request_timings.last_iteration`.
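A minimal sketch of that suggestion, reusing the names from the snippet above:

```python
# Sketch only: sample the clock once per parsed SSE line so every
# timing field updated in this iteration sees the same timestamp.
iter_time = time.time()

if line == b"[DONE]":
    if request_info.request_timings.request_end is None:
        request_info.request_timings.request_end = iter_time
    end_reached = True
elif "usage" in data and data["usage"] is not None:
    request_info.request_timings.request_end = iter_time
else:
    if request_info.request_timings.first_iteration is None:
        request_info.request_timings.first_iteration = iter_time
    request_info.request_timings.last_iteration = iter_time
```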

Comment on lines 354 to 373
```python
def _extract_response_text(self, data: dict) -> str:
    if not data:
        return None

    object_type = data.get("object") or data.get("type")

    if object_type == "text_completion":
        return data.get("choices", [{}])[0].get("text", "")

    if object_type == "chat.completion":
        return data.get("choices", [{}])[0].get("message", {}).get("content", "")

    if object_type == "chat.completion.chunk":
        return data.get("choices", [{}])[0].get("delta", {}).get("content", "")

    if "text" in data:
        return data.get("text", "")

    if "delta" in data:
        return data.get("delta", "")
```

After we rebase this, let's make it a match statement. I checked the bytecode and we avoid a few load ops.
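For illustration, a hedged sketch of the match-based version (Python 3.10+; branch bodies mirror the snippet above):

```python
# Sketch only: dispatch on the object/type field with match instead of
# an if chain. Wildcard-with-guard cases cover the "text"/"delta"
# fallbacks from the original.
def _extract_response_text(self, data: dict) -> str:
    if not data:
        return None

    match data.get("object") or data.get("type"):
        case "text_completion":
            return data.get("choices", [{}])[0].get("text", "")
        case "chat.completion":
            return (
                data.get("choices", [{}])[0].get("message", {}).get("content", "")
            )
        case "chat.completion.chunk":
            return data.get("choices", [{}])[0].get("delta", {}).get("content", "")
        case _ if "text" in data:
            return data.get("text", "")
        case _ if "delta" in data:
            return data.get("delta", "")
```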
